Skip to content

Add TurboQuant KV cache compression (3-bit, 4.6x)#1067

Open
arozanov wants to merge 4 commits intoml-explore:mainfrom
arozanov:feature/turboquant-kv-cache
Open

Add TurboQuant KV cache compression (3-bit, 4.6x)#1067
arozanov wants to merge 4 commits intoml-explore:mainfrom
arozanov:feature/turboquant-kv-cache

Conversation

@arozanov
Copy link
Copy Markdown

Summary

Adds TurboQuant KV cache compression (arXiv 2504.19874, ICLR 2026) as a new cache type for mlx-lm.

  • 4.6x compression at 3-bit (10 values packed per uint32)
  • 0.98x FP16 speed on Qwen2.5-32B (M4 Pro 48GB)
  • Identical output quality on 32B+ models
  • Drop-in: generate_step(prompt, model, turbo_kv_bits=3)

How it works

  1. PolarQuant: Randomized Hadamard rotation → coordinates become Gaussian → optimal Lloyd-Max scalar quantization
  2. Bit-packed storage: 3-bit indices packed into uint32 (10 per word)
  3. Fused Metal kernels: quantize (norm + WHT + codebook + pack) and dequantize (unpack + codebook + WHT + scale) in single GPU dispatches
  4. Incremental decode buffer: Only dequantize new tokens per step → O(1) decode cost
  5. Layer-adaptive mode: First/last N layers in FP16 for quality on smaller models

Results

Qwen2.5-32B-Instruct-4bit (M4 Pro 48GB):

Config Compression Speed Quality
FP16 1.0x 12.4 tok/s baseline
TQ3 (all layers) 4.6x 12.1 tok/s identical

Context scaling (32B, 3-bit):

Context FP16 Cache TQ3 Cache Saved
4K 1088 MB 225 MB 863 MB
8K 2112 MB 449 MB 1.7 GB
16K 4160 MB 897 MB 3.3 GB

Qwen2.5-7B with layer-adaptive (1+1 FP16 layers):

Config Compression Speed
FP16 1.0x 54 tok/s
TQ3 adaptive 2.8x 50 tok/s (0.93x)

Usage

from mlx_lm.generate import generate_step

# Simple: add turbo_kv_bits parameter
for token, logprobs in generate_step(prompt, model, turbo_kv_bits=3):
    ...

# Or create cache manually
from mlx_lm.generate import make_turboquant_cache
cache = make_turboquant_cache(model, bits=3, fp16_layers=1)

Files added

  • mlx_lm/models/turboquant_cache.py — TurboQuantKVCache (compatible with _BaseCache)
  • mlx_lm/models/turboquant_rotation.py — Walsh-Hadamard Transform
  • mlx_lm/models/turboquant_packing.py — Bit-packing utilities
  • mlx_lm/models/turboquant_metal.py — Fused Metal quantize/dequantize kernels
  • mlx_lm/models/turboquant_kernels.py — Parallel Metal dequant kernel
  • mlx_lm/generate.py — Added turbo_kv_bits and turbo_fp16_layers parameters

Related

Test plan

  • Correctness: per-vector cosine similarity 0.996+ on real KV values
  • Pack/unpack round-trip: exact match for 2/3/4-bit
  • End-to-end: coherent generation on Qwen2.5-7B and 32B
  • Speed: 0.93-0.98x FP16 across 512-16K context lengths
  • Perplexity evaluation on standard benchmarks
  • NIAH (needle-in-a-haystack) retrieval tests

Implements TurboQuant (arXiv 2504.19874) KV cache compression:
- PolarQuant: randomized Hadamard rotation + Lloyd-Max codebook
- Bit-packed uint32 storage (3-bit: 10 values per word)
- Fused Metal kernels for quantize and dequantize
- Incremental decode buffer for O(1) per-step cost
- Layer-adaptive mode: FP16 for first/last N layers

Usage:
  generate_step(prompt, model, turbo_kv_bits=3)

Results (Qwen2.5-32B, M4 Pro 48GB):
- 4.6x compression, 0.98x FP16 speed, identical quality
- 16K context: 4.2GB → 897MB KV cache
@kipanshi
Copy link
Copy Markdown

I tried this branch on GLM-4.7-Flash-REAP-23B-A3B-mlx-nvfp4 - it outputs garbage, on main branch it works fine

@arozanov
Copy link
Copy Markdown
Author

I tried this branch on GLM-4.7-Flash-REAP-23B-A3B-mlx-nvfp4 - it outputs garbage, on main branch it works fine

That's unexpected - this branch shouldn't change default behavior, it only adds new files and optional parameters. Are you using the default generate() or did you pass turbo_kv_bits? If default, there might be a formatting issue from pre-commit that touched generate.py - I'll check.

@kipanshi
Copy link
Copy Markdown

I tried this branch on GLM-4.7-Flash-REAP-23B-A3B-mlx-nvfp4 - it outputs garbage, on main branch it works fine

That's unexpected - this branch shouldn't change default behavior, it only adds new files and optional parameters. Are you using the default generate() or did you pass turbo_kv_bits? If default, there might be a formatting issue from pre-commit that touched generate.py - I'll check.

This is the script I used:

#!/bin/bash
# Run GLM-4.7-Flash-REAP-23B-A3B with TurboQuant KV cache
# Optimized for M1 Max 32GB

MODEL_DIR="$HOME/my_docs/llms/GLM-4.7-Flash-REAP-23B-A3B-mlx-mxfp4"
MLX_LM_DIR="$HOME/opt/mlx-lm"

TURBO_KV_BITS="${TURBO_KV_BITS:-4}"       # 3-bit = 4.6x compression, 4-bit = safer quality
TURBO_FP16_LAYERS="${TURBO_FP16_LAYERS:-1}" # first/last N layers stay FP16
MAX_TOKENS="${MAX_TOKENS:-4096}"
TEMP="${TEMP:-0.7}"
TOP_P="${TOP_P:-0.9}"

PROMPT="${1:-Hello, who are you?}"

cd "$MLX_LM_DIR" || exit 1

uv run python -c "
from mlx_lm import load, stream_generate
from mlx_lm.generate import make_sampler
import sys

model, tokenizer = load('${MODEL_DIR}')

sampler = make_sampler(temp=${TEMP}, top_p=${TOP_P})

prompt = sys.argv[1]
if tokenizer.has_chat_template:
    messages = [{'role': 'user', 'content': prompt}]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True,
    )

for response in stream_generate(
    model,
    tokenizer,
    prompt=prompt,
    max_tokens=${MAX_TOKENS},
    sampler=sampler,
    turbo_kv_bits=${TURBO_KV_BITS},
    turbo_fp16_layers=${TURBO_FP16_LAYERS},
):
    print(response.text, end='', flush=True)
print()
" "$PROMPT"

@arozanov
Copy link
Copy Markdown
Author

I tried this branch on GLM-4.7-Flash-REAP-23B-A3B-mlx-nvfp4 - it outputs garbage, on main branch it works fine

That's unexpected - this branch shouldn't change default behavior, it only adds new files and optional parameters. Are you using the default generate() or did you pass turbo_kv_bits? If default, there might be a formatting issue from pre-commit that touched generate.py - I'll check.

This is the script I used:

#!/bin/bash
# Run GLM-4.7-Flash-REAP-23B-A3B with TurboQuant KV cache
# Optimized for M1 Max 32GB

MODEL_DIR="$HOME/my_docs/llms/GLM-4.7-Flash-REAP-23B-A3B-mlx-mxfp4"
MLX_LM_DIR="$HOME/opt/mlx-lm"

TURBO_KV_BITS="${TURBO_KV_BITS:-4}"       # 3-bit = 4.6x compression, 4-bit = safer quality
TURBO_FP16_LAYERS="${TURBO_FP16_LAYERS:-1}" # first/last N layers stay FP16
MAX_TOKENS="${MAX_TOKENS:-4096}"
TEMP="${TEMP:-0.7}"
TOP_P="${TOP_P:-0.9}"

PROMPT="${1:-Hello, who are you?}"

cd "$MLX_LM_DIR" || exit 1

uv run python -c "
from mlx_lm import load, stream_generate
from mlx_lm.generate import make_sampler
import sys

model, tokenizer = load('${MODEL_DIR}')

sampler = make_sampler(temp=${TEMP}, top_p=${TOP_P})

prompt = sys.argv[1]
if tokenizer.has_chat_template:
    messages = [{'role': 'user', 'content': prompt}]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True,
    )

for response in stream_generate(
    model,
    tokenizer,
    prompt=prompt,
    max_tokens=${MAX_TOKENS},
    sampler=sampler,
    turbo_kv_bits=${TURBO_KV_BITS},
    turbo_fp16_layers=${TURBO_FP16_LAYERS},
):
    print(response.text, end='', flush=True)
print()
" "$PROMPT"

Ah got it, you are using turbo_kv_bits=4. That's expected to have quality issues on the K tensor - I've seen the same thing. Try turbo_kv_bits=3 which actually works better (counterintuitively, the 3-bit codebook fits the post-rotation Gaussian distribution better than 4-bit for K. Also for a 23B model try increasing turbo_fp16_layers=2 or turbo_fp16_layers=4 to keep more layers in full precision.

@arozanov
Copy link
Copy Markdown
Author

I tried this branch on GLM-4.7-Flash-REAP-23B-A3B-mlx-nvfp4 - it outputs garbage, on main branch it works fine

That's unexpected - this branch shouldn't change default behavior, it only adds new files and optional parameters. Are you using the default generate() or did you pass turbo_kv_bits? If default, there might be a formatting issue from pre-commit that touched generate.py - I'll check.

This is the script I used:

#!/bin/bash
# Run GLM-4.7-Flash-REAP-23B-A3B with TurboQuant KV cache
# Optimized for M1 Max 32GB

MODEL_DIR="$HOME/my_docs/llms/GLM-4.7-Flash-REAP-23B-A3B-mlx-mxfp4"
MLX_LM_DIR="$HOME/opt/mlx-lm"

TURBO_KV_BITS="${TURBO_KV_BITS:-4}"       # 3-bit = 4.6x compression, 4-bit = safer quality
TURBO_FP16_LAYERS="${TURBO_FP16_LAYERS:-1}" # first/last N layers stay FP16
MAX_TOKENS="${MAX_TOKENS:-4096}"
TEMP="${TEMP:-0.7}"
TOP_P="${TOP_P:-0.9}"

PROMPT="${1:-Hello, who are you?}"

cd "$MLX_LM_DIR" || exit 1

uv run python -c "
from mlx_lm import load, stream_generate
from mlx_lm.generate import make_sampler
import sys

model, tokenizer = load('${MODEL_DIR}')

sampler = make_sampler(temp=${TEMP}, top_p=${TOP_P})

prompt = sys.argv[1]
if tokenizer.has_chat_template:
    messages = [{'role': 'user', 'content': prompt}]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True,
    )

for response in stream_generate(
    model,
    tokenizer,
    prompt=prompt,
    max_tokens=${MAX_TOKENS},
    sampler=sampler,
    turbo_kv_bits=${TURBO_KV_BITS},
    turbo_fp16_layers=${TURBO_FP16_LAYERS},
):
    print(response.text, end='', flush=True)
print()
" "$PROMPT"

Found it. Your config turbo_kv_bits=4, turbo_fp16_layers=1 should work on most models, but MoE architectures like GLM-4.7-Flash might need more FP16 layers. Try turbo_fp16_layers=4 or turbo_fp16_layers=6. On my 7B tests, both 3-bit and 4-bit with fp16_layers=1 produce clean output.

@kipanshi
Copy link
Copy Markdown

Ah got it, you are using turbo_kv_bits=4. That's expected to have quality issues on the K tensor - I've seen the same thing. Try turbo_kv_bits=3 which actually works better (counterintuitively, the 3-bit codebook fits the post-rotation Gaussian distribution better than 4-bit for K. Also for a 23B model try increasing turbo_fp16_layers=2 or turbo_fp16_layers=4 to keep more layers in full precision.

whith params you suggested same garbage issue:
"""
Without turboquant the branch works fine — model outputs correctly. So the turboquant cache itself is incompatible with the glm4_moe_lite MLA
architecture. The MLA stores compressed latents (kv_lora_rank=512, qk_rope_head_dim=64) in the cache, not standard key/value tensors — and
turboquant's PolarQuant rotation likely can't handle that compressed representation correctly.
"""
I will try to test it with Qwen3.5 35B MoE

@arozanov
Copy link
Copy Markdown
Author

Ah got it, you are using turbo_kv_bits=4. That's expected to have quality issues on the K tensor - I've seen the same thing. Try turbo_kv_bits=3 which actually works better (counterintuitively, the 3-bit codebook fits the post-rotation Gaussian distribution better than 4-bit for K. Also for a 23B model try increasing turbo_fp16_layers=2 or turbo_fp16_layers=4 to keep more layers in full precision.

whith params you suggested same garbage issue: """ Without turboquant the branch works fine — model outputs correctly. So the turboquant cache itself is incompatible with the glm4_moe_lite MLA architecture. The MLA stores compressed latents (kv_lora_rank=512, qk_rope_head_dim=64) in the cache, not standard key/value tensors — and turboquant's PolarQuant rotation likely can't handle that compressed representation correctly. """ I will try to test it with Qwen3.5 35B MoE

Yeah MLA is a different beast, makes sense it breaks. Good catch. Qwen3.5 should be fine since it's standard attention. Let me know how it goes.

@kipanshi
Copy link
Copy Markdown

whith params you suggested same garbage issue: """ Without turboquant the branch works fine — model outputs correctly. So the turboquant cache itself is incompatible with the glm4_moe_lite MLA architecture. The MLA stores compressed latents (kv_lora_rank=512, qk_rope_head_dim=64) in the cache, not standard key/value tensors — and turboquant's PolarQuant rotation likely can't handle that compressed representation correctly. """ I will try to test it with Qwen3.5 35B MoE

Yeah MLA is a different beast, makes sense it breaks. Good catch. Qwen3.5 should be fine since it's standard attention. Let me know how it goes.

Did more testing:

  • GLM-4.7-Flash (MLA): loads but produces garbage (MLA latent cache incompatible)
  • Qwen3.5-35B-A3B (hybrid SSM/attention): crashes (SSM cache not supported)
  • Standard attention models (Llama, Mistral): works correctly

@arozanov
Copy link
Copy Markdown
Author

whith params you suggested same garbage issue: """ Without turboquant the branch works fine — model outputs correctly. So the turboquant cache itself is incompatible with the glm4_moe_lite MLA architecture. The MLA stores compressed latents (kv_lora_rank=512, qk_rope_head_dim=64) in the cache, not standard key/value tensors — and turboquant's PolarQuant rotation likely can't handle that compressed representation correctly. """ I will try to test it with Qwen3.5 35B MoE

Yeah MLA is a different beast, makes sense it breaks. Good catch. Qwen3.5 should be fine since it's standard attention. Let me know how it goes.

Did more testing:

  • GLM-4.7-Flash (MLA): loads but produces garbage (MLA latent cache incompatible)
  • Qwen3.5-35B-A3B (hybrid SSM/attention): crashes (SSM cache not supported)
  • Standard attention models (Llama, Mistral): works correctly

Thanks for testing across architectures. MLA and SSM are expected - TurboQuant only works with standard multi-head attention KV cache. I should add a check that raises a clear error instead of silently producing garbage. Will fix.

@babhishek21
Copy link
Copy Markdown

some thoughts for DX:

  1. Since this is a compression scheme, perhaps it should be given the same treatment as KVCache#to_quantized(). The basic unbounded KVCache could have a to_turbo_quantized() (or equivalent) that returns a TurboQuantKVCache.
  2. Possibility to have a generalized convert_to_turbo_quantized() function (similar to how [Experimental] Add TurboQuantKVCache: PolarQuant KV cache compression at 2-4 bits #1059 does it), with supporting cache specializations progressively adopting to_turbo_quantized() (where supported).
  3. make_prompt_cache should still be the entry point for feature enablement related to caches; in this particular case whether to enable TurboQuant compression or not. Similar to how max_kv_size causes a switch to bounded cache, params turboq_kv_bits and turboq_fp16_layers could switch on TurboQuant. Would also help dedupe all the logic around if hasattr(model, "make_cache"): return model.make_cache().
  4. Lib users are still able to pass in any custom prompt_cache to generate.
  5. CLI users should be able to pass in TurboQuant args in the same way they pass in --max-kv-size.

@arozanov
Copy link
Copy Markdown
Author

some thoughts for DX:

  1. Since this is a compression scheme, perhaps it should be given the same treatment as KVCache#to_quantized(). The basic unbounded KVCache could have a to_turbo_quantized() (or equivalent) that returns a TurboQuantKVCache.
  2. Possibility to have a generalized convert_to_turbo_quantized() function (similar to how [Experimental] Add TurboQuantKVCache: PolarQuant KV cache compression at 2-4 bits #1059 does it), with supporting cache specializations progressively adopting to_turbo_quantized() (where supported).
  3. make_prompt_cache should still be the entry point for feature enablement related to caches; in this particular case whether to enable TurboQuant compression or not. Similar to how max_kv_size causes a switch to bounded cache, params turboq_kv_bits and turboq_fp16_layers could switch on TurboQuant. Would also help dedupe all the logic around if hasattr(model, "make_cache"): return model.make_cache().
  4. Lib users are still able to pass in any custom prompt_cache to generate.
  5. CLI users should be able to pass in TurboQuant args in the same way they pass in --max-kv-size.

Good points, agree on all of them. Specifically:

  1. to_turbo_quantized() on KVCache - makes sense, will add
  2. Routing through make_prompt_cache instead of separate function - cleaner, agreed
  3. CLI args --turbo-kv-bits and --turbo-fp16-layers alongside --max-kv-size - will do
    I'll rework the PR to follow the existing patterns. Thanks for the detailed review.

@babhishek21
Copy link
Copy Markdown

@arozanov I think you'll need to add tests.
@awni @andresy with that, I think this PR will probably supersede #1059

arozanov pushed a commit to arozanov/vllm-mlx that referenced this pull request Mar 29, 2026
Adds --turbo-kv-bits flag (1-4) to compress stored prefix cache entries
using TurboQuant (arXiv 2504.19874). 3-bit gives 4.6x compression vs FP16,
compared to ~2x from the existing 8-bit quantization.

Integration points:
- memory_cache.py: _turbo_quantize_cache/_dequantize_cache, memory estimation,
  trim support, needs_dequantize property, config validation
- scheduler.py: turbo_kv_bits in SchedulerConfig, propagation to MemoryCacheConfig
- cli.py: --turbo-kv-bits for serve and bench commands

Requires mlx-lm with TurboQuant support (ml-explore/mlx-lm#1067).
arozanov added a commit to arozanov/vllm-mlx that referenced this pull request Mar 29, 2026
Adds --turbo-kv-bits flag (1-4) to compress stored prefix cache entries
using TurboQuant (arXiv 2504.19874). 3-bit gives 4.6x compression vs FP16,
compared to ~2x from the existing 8-bit quantization.

Integration points:
- memory_cache.py: _turbo_quantize_cache/_dequantize_cache, memory estimation,
  trim support, needs_dequantize property, config validation
- scheduler.py: turbo_kv_bits in SchedulerConfig, propagation to MemoryCacheConfig
- cli.py: --turbo-kv-bits for serve and bench commands

Requires mlx-lm with TurboQuant support (ml-explore/mlx-lm#1067).
@arozanov arozanov force-pushed the feature/turboquant-kv-cache branch from a087778 to 9315fbc Compare March 29, 2026 16:04
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants